/*
 * Copyright (c) Hisilicon Technologies Co., Ltd. 2022-2022. All rights reserved.
 * Description: legacy quantize saver
 */

#include "omg/quantize_optimizer/legacy_quantize_saver.h"

#include <list>
#include <algorithm>

#include "graph/types.h"

#include "infra/base/assertion.h"

#include "framework/graph/core/node/node_spec.h"
#include "framework/graph/core/node/node_walker.h"

#include "omg/quantize_optimizer/quantize_util.h"

using namespace std;
using namespace ge;

namespace hiai {
namespace {
// Maps an (input data type, weight data type) pair to the legacy quantize type.
//
// inputDataType:  data type configured for the quantized input.
// weightDataType: data type configured for the quantized weight.
// Returns the matching QuantizeType, or RESERVE_QUANTIZED (after logging an error) when the
// combination is not listed in the table below.
QuantizeType GetQuantType(ge::DataType inputDataType, ge::DataType weightDataType)
{
    using DataTypePair = std::pair<ge::DataType, ge::DataType>;

    // The table is constant, so it is built once on first use instead of on every call.
    static const std::map<DataTypePair, QuantizeType> quantTypeMaps = {
        { { DT_UINT8, DT_INT8 }, UINT8_INT8_QUANTIZED },
        { { DT_UINT8, DT_2BIT }, UINT8_INT2_QUANTIZED },
        { { DT_INT4, DT_INT4 }, INT4_INT4_QUANTIZED },
        { { DT_FLOAT16, DT_INT8 }, INT8_FILTER_QUANTIZED },
        { { DT_INT16, DT_INT8 }, INT16_INT8_QUANTIZED },
    };

    const auto it = quantTypeMaps.find({ inputDataType, weightDataType });
    if (it == quantTypeMaps.cend()) {
        FMK_LOGE("quant input data type:%d and weight data type:%d not supported.", inputDataType, weightDataType);
        return RESERVE_QUANTIZED;
    }
    return it->second;
}

// Selects the quantize algorithm from the input data type.
//
// Returns NON_OFFSET_ALGO when inputDataType is DT_INT16, DT_INT8 or DT_INT4, and
// HALF_OFFSET_ALGO for every other data type.
QuantizeAlgorithm GetQuantizeAlgo(ge::DataType inputDataType)
{
    // Three fixed candidates: compare directly rather than allocating a container per call.
    const bool isSignedInput =
        (inputDataType == ge::DT_INT16) || (inputDataType == ge::DT_INT8) || (inputDataType == ge::DT_INT4);
    if (isSignedInput) {
        return NON_OFFSET_ALGO;
    }

    return HALF_OFFSET_ALGO;
}

// Stores the quantize info and the quantize type of quantizeConfig on the op desc of `node`.
//
// quantizeConfig: source of the input scale/offset (element 0 only), the weight scales/offsets
//                 and the input/weight data types.
// node:           node whose op desc receives the attributes; must not be null.
// Returns the status of QuantizeUtil::SetQuantType on success. Fails when node is null, when the
// input scale or offset list is empty, when the data type pair has no quantize type, or when
// QuantizeUtil::SetQuantizeInfo fails.
Status SetQuantizeInfos(const QuantizeConfig& quantizeConfig, Node* node)
{
    HIAI_EXPECT_NOT_NULL(node);
    // Element 0 of both lists is read below, so reject empty lists instead of indexing out of range.
    HIAI_EXPECT_TRUE(!quantizeConfig.inputScale.empty());
    HIAI_EXPECT_TRUE(!quantizeConfig.inputOffset.empty());

    OpDesc& opDesc = node->ROLE(NodeSpec).OpDesc();

    // Resolve the quantize type before writing anything, so that an unsupported data type
    // combination does not leave the op desc with quantize info but no quantize type.
    QuantizeType quantType = GetQuantType(quantizeConfig.inputDataType, quantizeConfig.weightDataType);
    if (quantType == QuantizeType::RESERVE_QUANTIZED) {
        FMK_LOGE("Get quant type fail. node:%s.", opDesc.GetName().c_str());
        return hiai::FAILURE;
    }

    // Save the QuantizeInfo.
    QuantizeInfo quantInfo;
    quantInfo.scale_data_value = quantizeConfig.inputScale[0];
    quantInfo.offset_data_value = quantizeConfig.inputOffset[0];

    quantInfo.set_scale_weight_value(
        quantizeConfig.weightScale.data(), quantizeConfig.weightScale.size() * sizeof(float));
    quantInfo.set_offset_weight_value(
        quantizeConfig.weightOffset.data(), quantizeConfig.weightOffset.size() * sizeof(float));
    // More than one weight scale means the scales are stored as a vector rather than a scalar.
    quantInfo.scale_weight_mode = (quantizeConfig.weightScale.size() > 1) ? VECTOR_SCALE : SCALAR_SCALE;
    quantInfo.quantize_algo = static_cast<uint32_t>(GetQuantizeAlgo(quantizeConfig.inputDataType));

    HIAI_EXPECT_EXEC(QuantizeUtil::SetQuantizeInfo(opDesc, quantInfo));

    // Set the quantize type.
    return QuantizeUtil::SetQuantType(opDesc, static_cast<int64_t>(quantType));
}

// Checks that the weight scales already stored on sharedNode agree with quantizeConfig.weightScale.
//
// A node for which QuantizeUtil::GetQuantizeInfo does not succeed passes the check. Otherwise the
// stored scale buffer must have the same byte size as the configured scales, and every pair of
// scales must differ by no more than ZERO_EPS.
// Returns hiai::SUCCESS when the check passes, a failure status otherwise.
Status CheckSharedWeightQuantParams(const ge::Node& sharedNode, const QuantizeConfig& quantizeConfig)
{
    OpDesc& consumerDesc = sharedNode.ROLE(NodeSpec).OpDesc();
    QuantizeInfo quantInfo;
    if (QuantizeUtil::GetQuantizeInfo(consumerDesc, quantInfo) != hiai::SUCCESS) {
        // Nothing readable on this node: there is nothing to compare against.
        return hiai::SUCCESS;
    }

    HIAI_EXPECT_TRUE(quantInfo.scale_weight_value.GetSize() == quantizeConfig.weightScale.size() * sizeof(float));
    const float* currScales = reinterpret_cast<const float*>(quantInfo.scale_weight_value.GetData());
    HIAI_EXPECT_NOT_NULL(currScales);

    const size_t scaleCount = quantizeConfig.weightScale.size();
    for (size_t idx = 0; idx < scaleCount; idx++) {
        const bool scaleDiffers = fabs(currScales[idx] - quantizeConfig.weightScale[idx]) > ZERO_EPS;
        if (!scaleDiffers) {
            continue;
        }
        FMK_LOGE("Shared weight has different weight scales, op:%s.", consumerDesc.GetName().c_str());
        return hiai::FAILURE;
    }

    return hiai::SUCCESS;
}

// Reports whether the weight tensor obtained for filterNode already has a quantized data type.
//
// filterNode: node handed to QuantizeUtil::GetFilterTensor to look up the weight tensor.
// Returns true when the tensor's data type is DT_INT8, DT_INT4 or DT_2BIT; false otherwise,
// including when no tensor is returned.
bool CheckWeightQuantized(const ge::Node* filterNode)
{
    ge::TensorPtr filter = QuantizeUtil::GetFilterTensor(filterNode);
    // This function returns bool, so the result for a missing tensor is stated explicitly.
    // The plain HIAI_EXPECT_NOT_NULL is the form this file uses in Status-returning functions;
    // here its return value would be implicitly converted to bool.
    // NOTE(review): presumably that is a non-zero failure code, i.e. "already quantized" —
    // confirm against infra/base/assertion.h.
    HIAI_EXPECT_NOT_NULL_R(filter, false);

    const auto filterType = filter->GetTensorDesc().GetDataType();
    return filterType == ge::DT_INT8 || filterType == ge::DT_INT4 || filterType == ge::DT_2BIT;
}

// Saves the legacy quantize parameters on `node` and quantizes its weight if that is still needed.
//
// Steps, in order:
//   1. If the producer of data input 0 is a Data op, its input desc 0 must be DT_FLOAT or DT_FLOAT16.
//   2. Find the weight node; when it has more than one out edge, every out node is checked with
//      CheckSharedWeightQuantParams.
//   3. Write the quantize info and type onto `node` through SetQuantizeInfos.
//   4. Unless CheckWeightQuantized reports true, quantize the weight with QuantizeUtil::QuantizeWeight.
// Returns hiai::SUCCESS, or the failure status of the first step that fails.
Status SaveQuantizeInfo(const QuantizeConfig& quantizeConfig, Node* node, int64_t weightDataAddr = 0)
{
    // Every operator supported by the legacy quantize config takes its feature-map input at index 0.
    ge::Node* srcNode = node->ROLE(NodeWalker).InDataNode(0);
    HIAI_EXPECT_NOT_NULL(srcNode);
    if (srcNode->ROLE(NodeSpec).IsDataOp()) {
        ge::DataType dataInputType = srcNode->ROLE(NodeSpec).OpDesc().GetInputDesc(0).GetDataType();
        HIAI_EXPECT_TRUE(dataInputType == ge::DT_FLOAT || dataInputType == ge::DT_FLOAT16);
    }

    // With a shared weight, the quantize parameters of the nodes sharing it must be consistent.
    Node* filterNode = QuantizeUtil::FindQuantizeWeightNode(node);
    HIAI_EXPECT_NOT_NULL_R(filterNode, hiai::PARAM_INVALID);
    const bool weightIsShared = filterNode->ROLE(NodeSpec).OutEdgeSize() > 1;
    if (weightIsShared) {
        HIAI_EXPECT_EXEC(filterNode->ROLE(NodeWalker).ListOutNodes([&](ge::Node& sharedNode) {
            return CheckSharedWeightQuantParams(sharedNode, quantizeConfig);
        }));
    }

    // Save the quantize parameters.
    HIAI_EXPECT_EXEC(SetQuantizeInfos(quantizeConfig, node));

    if (CheckWeightQuantized(filterNode)) {
        return hiai::SUCCESS;
    }

    // FLOAT16 input combined with an INT8 weight is passed on as one-side quantization.
    const bool oneSideOnly =
        (quantizeConfig.inputDataType == DT_FLOAT16) && (quantizeConfig.weightDataType == DT_INT8);
    return QuantizeUtil::QuantizeWeight(
        *filterNode, quantizeConfig.weightDataType, quantizeConfig.weightScale, oneSideOnly, weightDataAddr);
}

// Parses the textual extended quantize info into quantizeInfoExt.
//
// quantInfoExt:    string from which "winoFlag" and, when winoFlag is 1, "nValue" are read via
//                  GetValueFromString.
// quantizeInfoExt: output. winoFlag receives the parsed flag, nValue the parsed value (0 unless
//                  winoFlag is 1), and biasOptmizeType is set to 0.
// Returns hiai::SUCCESS, or the failure status of GetValueFromString.
hiai::Status QuantInfoExtConvert(const string& quantInfoExt, QuantizeInfoExt& quantizeInfoExt)
{
    uint32_t winoFlag = 0;
    int32_t nValue = 0;
    HIAI_EXPECT_EXEC(GetValueFromString(quantInfoExt, "winoFlag", winoFlag));
    if (winoFlag == 1) {
        HIAI_EXPECT_EXEC(GetValueFromString(quantInfoExt, "nValue", nValue));
    }

    // Always publish the parsed flag: SaveQuantizeInfoExt also branches on winoFlag == 2, which
    // it could never observe while the flag was only written for the value 1. Writing every field
    // also leaves the output in a defined state for all inputs.
    quantizeInfoExt.winoFlag = winoFlag;
    quantizeInfoExt.nValue = nValue;
    quantizeInfoExt.biasOptmizeType = 0;
    return hiai::SUCCESS;
}

// Attribute name under which SetQuantizeInfoExt stores the extended quantize info.
const char* const OP_DESC_QUANT_INFO_EXT = "QuantizeInfoExt";

// Packs winoFlag, nValue and biasOptmizeType (each widened to int64_t) into one NamedAttrs value
// and sets it on opDesc under OP_DESC_QUANT_INFO_EXT.
// Returns hiai::SUCCESS when the attribute is set, hiai::FAILURE (after logging) otherwise.
hiai::Status SetQuantizeInfoExt(ge::OpDesc& opDesc, const QuantizeInfoExt& quantizeInfoExt)
{
    AttrValue::NamedAttrs extFields;
    extFields.SetAttr("winoFlag", AttrValue::CreateFrom(static_cast<int64_t>(quantizeInfoExt.winoFlag)));
    extFields.SetAttr("nValue", AttrValue::CreateFrom(static_cast<int64_t>(quantizeInfoExt.nValue)));
    extFields.SetAttr("biasOptmizeType", AttrValue::CreateFrom(static_cast<int64_t>(quantizeInfoExt.biasOptmizeType)));

    auto packedAttr = AttrValue::CreateFrom(extFields);
    const bool stored = (opDesc.SetAttr(OP_DESC_QUANT_INFO_EXT, packedAttr) == ge::GRAPH_SUCCESS);
    if (stored) {
        return hiai::SUCCESS;
    }

    FMK_LOGE("Set attr QuantizeInfoExt fail. node:%s.", opDesc.GetName().c_str());
    return hiai::FAILURE;
}

// Saves the extended quantize info of quantizeConfig on `node`.
//
// Does nothing and returns hiai::SUCCESS when the config has no extended info or it is empty.
// Otherwise the raw string is stored through QuantizeUtil::SetQuantInfoExt, then parsed:
//   - winoFlag == 1: the parsed fields are stored on the op desc through SetQuantizeInfoExt;
//   - winoFlag == 2: the op type is prefixed with "Wino".
// Returns hiai::SUCCESS, or the failure status of the first step that fails (including a null node).
Status SaveQuantizeInfoExt(const QuantizeConfig& quantizeConfig, Node* node)
{
    if ((!quantizeConfig.hasQuantInfoExt) || quantizeConfig.quantInfoExt.empty()) {
        return hiai::SUCCESS;
    }
    HIAI_EXPECT_NOT_NULL(node);
    OpDesc& opDesc = node->ROLE(NodeSpec).OpDesc();
    HIAI_EXPECT_EXEC(QuantizeUtil::SetQuantInfoExt(opDesc, quantizeConfig.quantInfoExt));

    // Value-initialize so that every field read below has a defined value regardless of which
    // fields the converter writes.
    QuantizeInfoExt quantizeInfoExt {};
    HIAI_EXPECT_EXEC(QuantInfoExtConvert(quantizeConfig.quantInfoExt, quantizeInfoExt));
    if (quantizeInfoExt.winoFlag == 1) {
        HIAI_EXPECT_EXEC(SetQuantizeInfoExt(opDesc, quantizeInfoExt));
    } else if (quantizeInfoExt.winoFlag == 2) {
        std::string opType = "Wino" + opDesc.GetType();
        opDesc.SetType(opType);
    }
    return hiai::SUCCESS;
}
} // namespace

// Saves the quantize parameters of one operator: first the quantize info (which may also quantize
// the weight), then the extended quantize info.
//
// quantizeConfig: quantize configuration of the operator.
// node:           node to update; a null node is rejected with hiai::FAILED.
// weightDataAddr: forwarded unchanged to SaveQuantizeInfo.
// Returns hiai::SUCCESS when both steps succeed, hiai::FAILED (after logging the node name) otherwise.
Status LegacyQuantizeSaver::SaveOpQuantParams(
    const QuantizeConfig& quantizeConfig, ge::Node* node, int64_t weightDataAddr)
{
    // The node is dereferenced immediately below, so reject a null pointer up front.
    HIAI_EXPECT_NOT_NULL_R(node, hiai::FAILED);

    // Captured by value up front for use in the error logs.
    const string nodeName = node->ROLE(NodeSpec).Name();
    if (SaveQuantizeInfo(quantizeConfig, node, weightDataAddr) != SUCCESS) {
        FMK_LOGE("Node: %s save quantize info failed.", nodeName.c_str());
        return hiai::FAILED;
    }
    if (SaveQuantizeInfoExt(quantizeConfig, node) != SUCCESS) {
        FMK_LOGE("Node: %s save quantize info ext failed.", nodeName.c_str());
        return hiai::FAILED;
    }
    return hiai::SUCCESS;
}
} // namespace hiai